DTSA 5511 Introduction to Machine Learning: Deep Learning
Final Report: Using Deep Learning to Model Ocean Surface Elevation using Real World Data
Author
Affiliation
Andrew Simms
University of Colorado Boulder
Published
December 9, 2024
1 Problem Description
Ocean wave prediction is used for maritime safety, wave energy conversion, and coastal engineering applications. This research explores deep learning approaches for predicting ocean surface elevation time-series using historical buoy measurements. While traditional wave analysis relies on statistical parameters including significant wave height (H_{m_0}) and energy period (T_e), many applications could benefit from more accurate wave-by-wave predictions of surface elevation.
The need for accurate wave prediction is particularly evident in wave energy applications, where Ringwood (2020) highlights challenges in control system optimization that depend on reliable wave forecasts. Abdelkhalik et al. (2016) demonstrates how wave predictions enable real-time optimization of energy extraction, showing that accurate forecasting directly impacts system performance and efficiency.
This project addresses the fundamental need for accurate near real-time wave prediction by developing deep learning models to forecast three-dimensional surface elevation time-series, focusing on maintaining both prediction accuracy and computational efficiency through models trained on previously collected measurements.
1.1 Data Sources
The study utilizes surface elevation wave measurements from the Coastal Data Information Program (CDIP), focusing on two strategic United States locations: Kaneohe Bay, Hawaii, and Nags Head, North Carolina. These locations were chosen because they have many years of real-time measurements and significant seasonal variations in wave conditions.
Figure 1 shows the Kaneohe Bay, Hawaii buoy location (CDIP 225). This buoy is located within the U.S. Navy's Wave Energy Test Site (WETS) (Coastal Data Information Program 2023a). This deep-water deployment at 84 m depth experiences a mixed wave climate with trade wind waves, North Pacific swell, and South Pacific swell. The site generally maintains consistent wave conditions due to trade wind dominance.
Figure 1: Map of CDIP Buoy 225 Location at the Wave Energy Test Site, Kaneohe Bay, Oahu, HI
Figure 2 shows the Nags Head, North Carolina buoy location (CDIP 243) (Coastal Data Information Program 2023b). This buoy is located near the Jennette's Pier Wave Energy Test Center. This site, at an intermediate water depth of 21 m, experiences primarily wind-driven wave conditions. The wave climate is highly variable due to the influence of Cape Hatteras weather systems, with conditions ranging from calm seas to severe storm events including tropical cyclones.
Surface elevation measurements are collected using Datawell Waverider DWR-MkIII buoys, detailed in Datawell BV (2006). These specialized oceanographic instruments capture three-dimensional displacement measurements at a sampling frequency of 1.28 Hz, providing high-resolution data of vertical, northward, and eastward wave motion. All measurements undergo CDIP’s standardized quality control and processing pipeline before being archived and made available for analysis.
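Since several later processing steps depend on it, the arithmetic linking the 1.28 Hz sampling rate to record length can be sketched directly (a minimal check, not code from the report):

```python
# Relating the DWR-MkIII sampling frequency to samples per 30-minute record.
SAMPLE_RATE_HZ = 1.28          # Datawell Waverider DWR-MkIII output rate
SECONDS_PER_HALF_HOUR = 30 * 60

samples_per_half_hour = int(SAMPLE_RATE_HZ * SECONDS_PER_HALF_HOUR)  # 2304
sample_spacing_s = 1 / SAMPLE_RATE_HZ  # 0.78125 s between samples
```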
Figure 3: CDIP Buoy 204 Deployed in Lower Cook Inlet Alaska. Photo from AOOS
1.2.1 Directional Reference Frame
As illustrated in Figure 4, the buoy’s movement is tracked in a three-dimensional reference frame, measuring displacements in vertical (Z), east-west (X), and north-south (Y) directions. Our prediction task focuses on forecasting these three displacement components in real-time as measurement data streams from the buoy. This multivariate time series prediction approach uses historical measurements of all three displacement components to forecast their future values over specified time horizons, providing a comprehensive representation of wave motion at each location.
Code
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns

sns.set_theme()

# Get a pleasing color palette
# colors = sns.color_palette("husl", 3)  # Using husl for distinct but harmonious colors
colors = sns.color_palette()
x_color = colors[0]
y_color = colors[1]
z_color = colors[2]

# Create figure
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')

# Create sphere
u = np.linspace(0, 2 * np.pi, 100)
v = np.linspace(0, np.pi, 100)
x = 0.4 * np.outer(np.cos(u), np.sin(v))
y = 0.4 * np.outer(np.sin(u), np.sin(v))
z = 0.4 * np.outer(np.ones(np.size(u)), np.cos(v))

# Plot semi-transparent sphere
ax.plot_surface(x, y, z, color='orange', alpha=0.3)

# Plot axes through sphere center
length = 0.6
ax.plot([-length, length], [0, 0], [0, 0], color=x_color, linewidth=2, label='X (East/West)')
ax.plot([0, 0], [-length, length], [0, 0], color=y_color, linewidth=2, label='Y (True North/South)')
ax.plot([0, 0], [0, 0], [-length, length], color=z_color, linewidth=2, label='Z (Vertical)')

# Add arrows at the ends
arrow_length = 0.1
# X axis arrows
ax.quiver(length, 0, 0, arrow_length, 0, 0, color=x_color, arrow_length_ratio=0.3)
ax.quiver(-length, 0, 0, -arrow_length, 0, 0, color=x_color, arrow_length_ratio=0.3)
# Y axis arrows
ax.quiver(0, length, 0, 0, arrow_length, 0, color=y_color, arrow_length_ratio=0.3)
ax.quiver(0, -length, 0, 0, -arrow_length, 0, color=y_color, arrow_length_ratio=0.3)
# Z axis arrows
ax.quiver(0, 0, length, 0, 0, arrow_length, color=z_color, arrow_length_ratio=0.3)
ax.quiver(0, 0, -length, 0, 0, -arrow_length, color=z_color, arrow_length_ratio=0.3)

# Set equal aspect ratio
ax.set_box_aspect([1, 1, 1])

# Set axis limits
limit = 0.55
ax.set_xlim([-limit, limit])
ax.set_ylim([-limit, limit])
ax.set_zlim([-limit, limit])

# Add grid
ax.grid(True, alpha=0.3)

# Add axis labels with matching colors
ax.set_xlabel('East/West Displacement (X) [m]', color=x_color, weight='bold', fontsize=18)
ax.set_ylabel('True North/South Displacement (Y) [m]', color=y_color, weight='bold', fontsize=18)
ax.set_zlabel('Vertical Displacement (Z) [m]', color=z_color, weight='bold', fontsize=18)

# Adjust view angle
ax.view_init(elev=20, azim=180 - 45)

# Set background color to white
ax.set_facecolor('white')
fig.patch.set_facecolor('white')

plt.tight_layout()
plt.show()
Figure 4: Directional Reference Frame of Datawell Waverider DWR-MkIII Buoy
1.3 Deep Learning Architecture Overview
This project leverages both Long Short-Term Memory (LSTM) networks and Transformer architectures to predict ocean surface elevation measurements. LSTMs, first introduced by Hochreiter and Schmidhuber (1997), have demonstrated success in temporal sequence learning through their ability to capture long-term dependencies. The Transformer architecture, developed by Vaswani et al. (2023), offers an alternative approach using self-attention mechanisms to process sequential data without recurrence.
Previous work in ocean wave forecasting has shown promise using neural network approaches. Mandal and Prabaharan (2006) demonstrated effective wave height prediction using recurrent neural networks, while Kumar, Savitha, and Mamun (2017) explored sequential learning algorithms for regional wave height forecasting. Building on these foundations, our approach implements both LSTM and Transformer models using the PyTorch framework by Ansel et al. (2024), and the LSTM Module described by Sak, Senior, and Beaufays (2014), allowing for direct comparison of their performance in predicting three-dimensional surface elevation time series.
Figure 5 shows a systematic deep learning workflow optimized for wave prediction modeling. Beginning with exploratory data analysis of CDIP buoy measurements, the data undergoes preprocessing including normalization and temporal windowing. The model development phase explores multiple neural network architectures in parallel, followed by rigorous evaluation and hyperparameter tuning. The final model undergoes cross-site validation to assess its generalization capabilities across different ocean environments.
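The temporal windowing step mentioned above can be sketched as follows; the window lengths and the non-overlapping stride are illustrative assumptions, not the report's actual configuration:

```python
import numpy as np

# Hypothetical windowing sketch: slice a 3-channel displacement series
# (vertical, north/south, east/west) into fixed-length input windows and
# matching target windows for sequence-to-sequence training.
def make_windows(series, input_len, target_len):
    """series: (n_samples, 3) -> (n_windows, input_len, 3), (n_windows, target_len, 3)."""
    n = series.shape[0]
    xs, ys = [], []
    # Non-overlapping input windows; each target follows its input directly
    for start in range(0, n - input_len - target_len + 1, input_len):
        xs.append(series[start : start + input_len])
        ys.append(series[start + input_len : start + input_len + target_len])
    return np.stack(xs), np.stack(ys)

demo = np.random.default_rng(0).normal(size=(2304, 3))  # one 30-minute segment
X, y = make_windows(demo, input_len=128, target_len=32)
```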
2 Exploratory Data Analysis
2.1 Displacement Measurement
Ocean wave measurements from CDIP buoys utilize sophisticated sensor arrays including accelerometers, magnetometers, and GPS sensors to track three-dimensional wave motion (Datawell BV 2006). These instruments output vertical displacement (commonly referred to as surface elevation), along with northward and eastward displacements. Figure 6 illustrates these three displacement components using a 30-minute sample from CDIP 225, demonstrating the typical wave motion captured by a Datawell Waverider buoy.
Code
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

sns.set_theme()

df = pd.read_parquet("../data/a2_std_partition/station_number=0225/year=2017/month=11/day=11/hour=11/minute=00/data_20171111_1100.parquet")

df['vert_displacement_meters'].plot(figsize=(9, 2), linewidth=0.85, xlabel="Time", ylabel="Vertical\nDisplacement [m]")
plt.show()
df['north_displacement_meters'].plot(figsize=(9, 2), linewidth=0.85, xlabel="Time", ylabel="North/South\nDisplacement [m]")
plt.show()
df['east_displacement_meters'].plot(figsize=(9, 2), linewidth=0.85, xlabel="Time", ylabel="East/West\nDisplacement [m]")
plt.show()
(a) Vertical (Z) Displacement
(b) North/South (Y) Displacement
(c) East/West (X) Displacement
Figure 6: CDIP 225 30 minutes of Displacement - November 11, 2017 @ 11 am. Data from Coastal Data Information Program (2023a)
In Figure 7, a detailed view of approximately 90 seconds of the same time series reveals the fine-scale structure of wave motion. While the overall pattern exhibits typical wave behavior, the vertical displacement shows distinct non-sinusoidal characteristics, particularly during directional transitions. These abrupt changes in vertical motion highlight the complex, non-linear nature of real ocean waves compared to idealized wave forms.
Figure 7: CDIP 225 1.5 minutes of Displacement - November 11, 2017 @ 11 am. Data from Coastal Data Information Program (2023a)
2.2 Data Source and Download
The displacement data used in this study was obtained from CDIP’s THREDDS data server. Data for both locations - the Wave Energy Test Site (CDIP 225) and Nags Head (CDIP 243) - is hosted on CDIP’s archive server with standardized displacement time series documentation. The data is provided in NetCDF format.
Each data file contains three-dimensional displacement measurements sampled at 1.28 Hz, along with corresponding quality control flags. The raw NetCDF files were processed using xarray by Hoyer and Joseph (2017) for efficient handling of the multidimensional data. Timestamps were generated according to CDIP specifications, accounting for sampling rates and filter delays. The processed measurements were then consolidated into pandas DataFrames, developed by The pandas development team (n.d.), and stored in parquet format for efficient access during model development.
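The timestamp-generation scheme described above (a file start time plus accumulated sampling intervals) can be illustrated with a small sketch; the start time below is invented for the example:

```python
import numpy as np
import pandas as pd

# Illustrative timestamp generation: CDIP files store a start time and a
# sample rate, and per-sample timestamps are derived by accumulating the
# sampling interval. The start time here is made up for the example.
sample_rate_hz = 1.28
n_samples = 2304                       # one 30-minute segment
start = pd.Timestamp("2017-11-11 11:00:00", tz="UTC")

# 1/1.28 s = 781,250,000 ns per sample; accumulate as integer nanoseconds
offsets_ns = (np.arange(n_samples) * (1e9 / sample_rate_hz)).astype("int64")
times = start + pd.to_timedelta(offsets_ns, unit="ns")
```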
Listing 1: CDIP Download Implementation
from pathlib import Path

import requests
import xarray as xr
import numpy as np
import pandas as pd
import seaborn as sns

sns.set_theme()

# NAGS HEAD, NC - 243
# station_number = "243"
# KANEOHE BAY, WETS, HI - 225
station_number = "225"

# 1.28 Hz * 30 minutes
SAMPLES_PER_HALF_HOUR = 2304


def get_cdip_displacement_df(station_number, dataset_number):
    fname = f"{station_number}p1_d{dataset_number}.nc"
    nc_path = Path(f"./data/00_raw/{fname}").resolve()
    print(f"Opening {nc_path} if it exists...")
    if nc_path.exists() is False:
        nc_url = f"https://thredds.cdip.ucsd.edu/thredds/fileServer/cdip/archive/{station_number}p1/{fname}"
        print("Downloading", nc_url)
        # Download the NetCDF file using requests
        response = requests.get(nc_url)
        with open(nc_path, "wb") as f:
            f.write(response.content)

    # Open the downloaded NetCDF file with xarray
    ds = xr.open_dataset(nc_path)

    # Extract the relevant variables from the dataset
    xdisp = ds["xyzXDisplacement"]  # North/South Displacement (X)
    ydisp = ds["xyzYDisplacement"]  # East/West Displacement (Y)
    zdisp = ds["xyzZDisplacement"]  # Vertical Displacement (Z)
    qc_flag = ds["xyzFlagPrimary"]  # Quality control flag

    # For some reason all of these are missing one sample,
    # so we remove the last section
    xdisp = xdisp[:-(SAMPLES_PER_HALF_HOUR)]
    ydisp = ydisp[:-(SAMPLES_PER_HALF_HOUR)]
    zdisp = zdisp[:-(SAMPLES_PER_HALF_HOUR)]
    qc_flag = qc_flag[:-(SAMPLES_PER_HALF_HOUR)]

    filter_delay = ds["xyzFilterDelay"].values
    start_time = ds["xyzStartTime"].values  # Start time of buoy data collection
    sample_rate = float(ds["xyzSampleRate"].values)  # Sample rate of buoy data collection
    sample_rate = round(sample_rate, 2)

    print(f"Station Number: {station_number}, dataset_number: {dataset_number}, sample_rate: {sample_rate}")
    print(f"Len xdisp: {len(xdisp)}, num 30 min sections = {(len(xdisp) + 1) / 2304}")
    print(f"Filter delay: {filter_delay}")

    sample_delta_t_seconds = 1 / sample_rate
    sample_delta_t_nanoseconds = sample_delta_t_seconds * 1e9
    n_times = len(xdisp)

    start_time_ns = start_time.astype("int64")  # Convert start_time to nanoseconds
    # start_time_ns -= filter_delay * 1e9
    time_increments = np.arange(n_times) * sample_delta_t_nanoseconds  # Create an array of time increments
    times = start_time_ns + time_increments
    time = pd.to_datetime(times, unit="ns", origin="unix", utc=True)  # type: ignore

    df = pd.DataFrame(
        {
            "north_displacement_meters": xdisp,
            "east_displacement_meters": ydisp,
            "vert_displacement_meters": zdisp,
            "qc_displacement": qc_flag,
        },
        index=time,
    )
    return df


station_number = "225"
station_number = "243"

df_1 = get_cdip_displacement_df(station_number, "01")
df_2 = get_cdip_displacement_df(station_number, "02")
df_3 = get_cdip_displacement_df(station_number, "03")
df_4 = get_cdip_displacement_df(station_number, "04")
# df_5 = get_cdip_displacement_df(station_number, "05")

# df_all = pd.concat([df_1, df_2, df_3, df_4, df_5], axis="index")
df_all = pd.concat([df_1, df_2, df_3, df_4], axis="index")
# df_all = pd.concat([df_1, df_2, df_3], axis="index")
df_all = df_all.sort_index()
df_all.to_parquet(f"./data/a1_one_to_one_parquet/{station_number}_all.parquet")
print(df_all.info())
print(f"Successfully saved {station_number}_all.parquet!")
The CDIP download, shown in Listing 1, demonstrates the automated download and processing pipeline. This code handles the retrieval of NetCDF files, extraction of displacement measurements, timestamp generation, and data organization into a structured format suitable for analysis and model training. Quality control flags are preserved throughout the processing pipeline to ensure data integrity.
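As one hedged illustration of how the preserved QC flags could be applied downstream, the snippet below keeps only samples with a hypothetical "good" flag value of 1; the actual CDIP flag convention should be confirmed against the station documentation:

```python
import pandas as pd

# Hypothetical QC screen: keep only samples whose primary flag equals
# the assumed "good" value (1). The flag convention is an assumption
# for illustration, not taken from the CDIP documentation.
def drop_flagged(df, good_flag=1):
    return df[df["qc_displacement"] == good_flag]

# Toy frame mimicking the columns produced by the download pipeline
demo = pd.DataFrame({
    "vert_displacement_meters": [0.1, 0.2, 9.9, 0.0],
    "qc_displacement": [1, 1, 4, 1],
})
clean = drop_flagged(demo)
```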
2.3 Available Data
This section examines the temporal extent and characteristics of the available wave measurements.
2.3.1 Duration
Code
first_225 = pd.read_parquet("../data/a2_std_partition/station_number=0225/year=2016/month=08/day=26/hour=22/minute=00/data_20160826_2200.parquet")
first_225_timestamp = first_225.index[0]

last_225 = pd.read_parquet("../data/a2_std_partition/station_number=0225/year=2024/month=09/day=11/hour=18/minute=30/data_20240911_1830.parquet")
last_225_timestamp = last_225.index[-1]

first_243 = pd.read_parquet("../data/a2_std_partition/station_number=0243/year=2018/month=08/day=26/hour=15/minute=00/data_20180826_1500.parquet")
first_243_timestamp = first_243.index[0]

last_243 = pd.read_parquet("../data/a2_std_partition/station_number=0243/year=2023/month=07/day=12/hour=23/minute=30/data_20230712_2330.parquet")
last_243_timestamp = last_243.index[-1]

# Create the data
data = {
    'station': ['225', '243'],
    'start_date': [first_225_timestamp, first_243_timestamp],
    'end_date': [last_225_timestamp, last_243_timestamp],
}

# Create DataFrame
df = pd.DataFrame(data)

# Calculate duration
df['duration'] = df['end_date'] - df['start_date']


# Function to format duration in human readable format
def format_duration(timedelta):
    years = timedelta.days // 365
    remaining_days = timedelta.days % 365
    months = remaining_days // 30
    days = remaining_days % 30
    parts = []
    if years > 0:
        parts.append(f"{years} {'year' if years == 1 else 'years'}")
    if months > 0:
        parts.append(f"{months} {'month' if months == 1 else 'months'}")
    if days > 0:
        parts.append(f"{days} {'day' if days == 1 else 'days'}")
    return ", ".join(parts)


# Add human readable duration
df['duration_human'] = df['duration'].apply(format_duration)

# Format datetime columns to be more readable
df['start_date'] = df['start_date'].dt.strftime('%Y-%m-%d %H:%M')
df['end_date'] = df['end_date'].dt.strftime('%Y-%m-%d %H:%M')

df = df.rename({
    'start_date': "Start Date [UTC]",
    'end_date': "End Date [UTC]",
    'duration_human': "Duration",
}, axis="columns")
Code
df[['Start Date [UTC]', 'End Date [UTC]', 'Duration']]
Table 2: Temporal Details of Downloaded CDIP Data
Station | Start Date [UTC] | End Date [UTC] | Duration
225 - Kaneohe Bay, HI | 2016-08-26 22:00 | 2024-09-11 19:00 | 8 years, 17 days
243 - Nags Head, NC | 2018-08-26 15:00 | 2023-07-13 00:00 | 4 years, 10 months, 21 days
Based on the information in Table 2, the CDIP buoy datasets provide extensive coverage for both locations: WETS (CDIP 225) spans approximately 8 years (2016-2024), while Nags Head (CDIP 243) covers nearly 5 years (2018-2023). Both datasets contain three-dimensional displacement measurements at 1.28 Hz sampling frequency.
2.4 Displacement Statistics
For each location, we analyzed the statistical characteristics of the three-dimensional displacement measurements.
Code
import os

import duckdb


def calculate_column_stats(partition_path, column_name):
    con = duckdb.connect()
    con.execute("SET enable_progress_bar = false;")
    query = f"""
        SELECT
            '{column_name}' as column_name,
            COUNT({column_name}) as count,
            COUNT(DISTINCT {column_name}) as unique_count,
            SUM(CASE WHEN {column_name} IS NULL THEN 1 ELSE 0 END) as null_count,
            MIN({column_name}) as min_value,
            MAX({column_name}) as max_value,
            AVG({column_name}::DOUBLE) as mean,
            STDDEV({column_name}::DOUBLE) as std_dev,
            PERCENTILE_CONT(0.25) WITHIN GROUP (ORDER BY {column_name}::DOUBLE) as q1,
            PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY {column_name}::DOUBLE) as median,
            PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY {column_name}::DOUBLE) as q3
        FROM read_parquet('{partition_path}/**/*.parquet', hive_partitioning=true)
        WHERE {column_name} IS NOT NULL
    """
    stats_df = con.execute(query).df()
    con.close()
    return stats_df


def analyze_displacement_data(base_path, columns_to_analyze, station_numbers, output_path, overwrite=False):
    # Check if stats file already exists
    if os.path.exists(output_path) and not overwrite:
        return pd.read_parquet(output_path)

    all_stats = []
    for station in station_numbers:
        station_str = f"{station:04d}"  # Format station number with leading zeros
        partition_path = f"{base_path}/station_number={station_str}"
        if not os.path.exists(partition_path):
            print(f"Skipping station {station_str} - path does not exist")
            continue
        print(f"Processing station {station_str}...")
        for column in columns_to_analyze:
            try:
                stats_df = calculate_column_stats(partition_path, column)
                stats_df['station'] = station_str
                all_stats.append(stats_df)
                print(f"  Completed analysis of {column}")
            except Exception as e:
                print(f"  Error processing {column} for station {station_str}: {str(e)}")

    # Combine all results
    if all_stats:
        combined_stats = pd.concat(all_stats, ignore_index=True)
        # Create output directory if it doesn't exist
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        # Save to parquet
        combined_stats.to_parquet(output_path, index=False)
        print(f"\nStatistics saved to {output_path}")
        return combined_stats
    else:
        print("No statistics were generated")
        return None


# Example usage
base_path = "../data/a2_std_partition"
columns_to_analyze = [
    "vert_displacement_meters",
    "north_displacement_meters",
    "east_displacement_meters",
]
station_numbers = [225, 243]  # Add more station numbers as needed
output_path = "../data/displacement_stats.parquet"

# Run the analysis - will load existing file if it exists
stats_df = analyze_displacement_data(
    base_path=base_path,
    columns_to_analyze=columns_to_analyze,
    station_numbers=station_numbers,
    output_path=output_path,
    overwrite=False,  # Set to True to force recalculation
)

stats_df["Range [m]"] = stats_df['max_value'] + stats_df['min_value'].abs()

stats_df.loc[stats_df['column_name'] == 'vert_displacement_meters', 'column_name'] = "Vertical Displacement [m]"
stats_df.loc[stats_df['column_name'] == 'north_displacement_meters', 'column_name'] = "North/South Displacement [m]"
stats_df.loc[stats_df['column_name'] == 'east_displacement_meters', 'column_name'] = "East/West Displacement [m]"
stats_df.loc[stats_df['station'] == '0225', 'station'] = "225 - Kaneohe Bay, HI"
stats_df.loc[stats_df['station'] == '0243', 'station'] = "243 - Nags Head, NC"

stats_df = stats_df.rename({
    "count": "Count",
    "min_value": "Min [m]",
    "max_value": "Max [m]",
    "mean": "Mean [m]",
    "std_dev": "Standard Deviation [m]",
}, axis="columns")
Figure 9 shows that at WETS (CDIP 225), vertical displacements range approximately \pm5\,\mathrm{m}, while north-south and east-west displacements show similar ranges of about \pm6\,\mathrm{m}. The Nags Head site (CDIP 243) experiences larger displacement ranges, with vertical motion reaching \pm13\,\mathrm{m} and horizontal displacements extending up to \pm52\,\mathrm{m}, reflecting its more dynamic wave climate. All displacement components at both locations show near-zero means with standard deviations between 0.33 and 0.45\,\mathrm{m}, indicating symmetric wave motion about the mean position. The anomalous 52\,\mathrm{m} range in horizontal displacement at Nags Head suggests potential outliers or measurement artifacts that should be filtered prior to model training to ensure data quality.
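One possible screen for such measurement artifacts is a simple standard-deviation threshold; the sketch below is an assumption for illustration, and the cutoff k is not taken from the report:

```python
import numpy as np
import pandas as pd

# Hypothetical outlier screen: keep samples within k standard deviations
# of the mean. The threshold k = 10 is an illustrative assumption.
def mask_outliers(series, k=10.0):
    mu, sigma = series.mean(), series.std()
    return (series - mu).abs() <= k * sigma

# Synthetic displacements resembling the observed statistics
# (std ~0.4 m) plus one 52 m excursion like the anomaly noted above
demo = pd.Series(
    np.concatenate([np.random.default_rng(1).normal(0, 0.4, 1000), [52.0]])
)
keep = mask_outliers(demo)
```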
2.5 Data Partitioning
To facilitate efficient data handling and model development, we partitioned the continuous time series into 30-minute segments using a hierarchical storage structure. Each segment contains 2,304 samples (corresponding to the 1.28 Hz sampling rate) and is organized using a hive partitioning strategy based on temporal metadata (year, month, day, hour, minute) and station number.
The implementation, shown in Listing 2, creates a systematic file structure where each 30-minute measurement period is stored as an individual parquet file. This organization enables efficient data loading during model training and validation, while maintaining the temporal relationship between segments. The hierarchical structure also facilitates parallel processing and selective data loading based on specific time periods or stations.
Listing 2: Data Partitioning Implementation
def partition_df(this_df, station_number, output_folder):
    # Ensure output folder exists
    output_folder.mkdir(parents=True, exist_ok=True)

    # Sort the DataFrame by index (assuming timestamp index)
    this_df = this_df.sort_index()

    # Function to create partition path
    def create_partition_path(timestamp, station):
        return (
            f"station_number={station:04d}/"
            f"year={timestamp.year:04d}/"
            f"month={timestamp.month:02d}/"
            f"day={timestamp.day:02d}/"
            f"hour={timestamp.hour:02d}/"
            f"minute={timestamp.minute:02d}"
        )

    # Process data in chunks of 2304 samples
    chunk_size = 2304
    num_chunks = len(this_df) // chunk_size

    for i in range(num_chunks):
        start_idx = i * chunk_size
        end_idx = (i + 1) * chunk_size

        # Get chunk of data
        chunk_df = this_df.iloc[start_idx:end_idx]

        # Verify chunk duration
        chunk_duration = chunk_df.index[-1] - chunk_df.index[0]
        expected_duration = timedelta(minutes=30)

        # Use start time of chunk for partitioning
        chunk_start_time = chunk_df.index[0]

        # Create partition path
        partition_path = create_partition_path(chunk_start_time, station_number)
        full_path = output_folder / partition_path

        # Create directory structure
        full_path.mkdir(parents=True, exist_ok=True)

        # Save the partitioned data
        output_file = full_path / f"data_{chunk_start_time.strftime('%Y%m%d_%H%M')}.parquet"
        chunk_df.to_parquet(output_file)

        print(f"Saved partition: {partition_path}")
        print(f"Chunk {i} duration: {chunk_duration}")

    # Handle any remaining data
    if len(this_df) % chunk_size != 0:
        remaining_df = this_df.iloc[num_chunks * chunk_size:]
        print(f"Warning: {len(remaining_df)} samples remaining at end of file")
        print(f"Last timestamp: {remaining_df.index[-1]}")
2.7 Calculating Statistical Wave Parameters
The multi-year displacement dataset is too large to train a model on directly within the scope of this project, so we subset the problem. Transforming the raw displacement measurements into 30-minute statistical wave parameters using MHKiT-Python (Fao et al. 2024) provides a clearer view of the wave conditions at each site. These statistical metrics - significant wave height (H_{m_0}), energy period (T_e), and omnidirectional wave energy flux (J) - help identify unique wave conditions and temporal patterns within the large dataset.
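For reference, these two statistics can be computed directly from spectral moments, m_n = \int f^n S(f)\,df, with H_{m_0} = 4\sqrt{m_0} and T_e = m_{-1}/m_0. The report uses MHKiT for this; the sketch below applies the formulas directly to a synthetic, narrow-banded demo spectrum:

```python
import numpy as np

# Trapezoidal spectral moment m_n = integral of f^n * S(f) df
def spectral_moment(f, S, n):
    y = (f ** n) * S
    return float(np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(f)))

def hm0_te(f, S):
    m0 = spectral_moment(f, S, 0)
    m_neg1 = spectral_moment(f, S, -1)
    return 4.0 * np.sqrt(m0), m_neg1 / m0

# Synthetic narrow-banded spectrum centered at 0.1 Hz (arbitrary units)
f = np.linspace(0.05, 0.5, 1000)
S = np.exp(-0.5 * ((f - 0.1) / 0.005) ** 2)
Hm0, Te = hm0_te(f, S)  # Te should be close to 1 / 0.1 Hz = 10 s
```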
The implementation shown in Listing 3 computes these wave parameters.
Listing 3: Calculation of Wave Quantities of Interest from Displacement
Figure 18: Omnidirectional Wave Energy Flux (Wave Power), J [W/m]
As shown in Figure 13 through Figure 15, WETS (CDIP 225) exhibits more periodic behavior while Nags Head (CDIP 243) shows greater variability. A notable data gap exists in the WETS measurements from 2021 to 2023. Focusing on 2019 (Figure 16 through Figure 18) highlights that Nags Head experiences higher peak wave power and wave heights but a narrower range of wave periods compared to WETS. These insights will guide our selection of representative data segments for model development.
2.8 Sea State Analysis
To understand the distribution of wave conditions at each site, we developed sea state matrices that bin the data by significant wave height (H_{m_0}) and energy period (T_e). This categorization quantifies the frequency of different wave conditions and helps ensure our training dataset encompasses a representative range of height and period states.
Code
import numpy as np


def plot_wave_heatmap(df, figsize=(12, 8)):
    # Create bins for Hm0 and Te
    hm0_bins = np.arange(0, df['significant_wave_height_meters'].max() + 0.5, 0.5)
    te_bins = np.arange(0, df['energy_period_seconds'].max() + 1, 1)

    # Use pd.cut to bin the data
    hm0_binned = pd.cut(df['significant_wave_height_meters'], bins=hm0_bins, labels=hm0_bins[:-1], include_lowest=True)
    te_binned = pd.cut(df['energy_period_seconds'], bins=te_bins, labels=te_bins[:-1], include_lowest=True)

    # Create cross-tabulation of binned data
    counts = pd.crosstab(hm0_binned, te_binned)
    counts = counts.sort_index(ascending=False)

    # Replace 0 counts with NaN
    counts = counts.replace(0, np.nan)

    # Create figure and axis
    plt.figure(figsize=figsize)

    # Create heatmap using seaborn
    ax = sns.heatmap(
        counts,
        cmap='viridis',
        annot=True,  # Add count annotations
        fmt='.0f',   # Format annotations as integers
        cbar_kws={'label': 'Count'},
    )

    # Customize plot
    plt.xlabel('Energy Period Te (s)')
    plt.ylabel('Significant Wave Height Hm0 (m)')

    # Rotate y-axis labels for better readability
    plt.yticks(rotation=90)

    # Adjust layout to prevent label cutoff
    plt.tight_layout()
    plt.show()
Figure 19: Distribution of Sea States at WETS (CDIP 225), Showing Occurrence Count of Combined Significant Wave Height (H_{m_0}) and Energy Period (T_e) Conditions
Code
plot_wave_heatmap(qoi_243)
Figure 20: Distribution of Sea States at Nags Head (CDIP 243), Showing Occurrence Count of Combined Significant Wave Height (H_{m_0}) and Energy Period (T_e) Conditions
The sea state matrices reveal distinct wave climates at each location. As shown in Figure 19, the WETS site exhibits wave heights predominantly below 5 m, with energy periods clustered between 6-12 seconds. In contrast, Figure 20 shows that Nags Head experiences a broader range of conditions, with wave heights reaching 7.5 m and energy periods spanning 3-15 seconds. The presence of 3-second periods at Nags Head likely indicates local wind-generated waves, a characteristic absent at WETS.
These differing wave climate characteristics confirm that our dataset captures a diverse range of sea states across both locations, providing a robust foundation for model development. The comprehensive coverage of wave conditions suggests we can proceed with creating training datasets that will expose our models to the full spectrum of expected wave behaviors.
3 Data Cleaning and Dataset Generation
3.1 Structured Sampling Method
Based on the sea state distributions shown in Figure 19 and Figure 20, we developed a systematic sampling method to create balanced training datasets. The approach, implemented in Listing 4, samples time series segments from each combination of significant wave height and energy period bins, ensuring representation across the full range of observed wave conditions.
Listing 4: Function to Sample Sea State Matrix Bins
def sample_wave_bins(df, n_samples=1, hm0_step=0.5, te_step=1.0):
    # Create bins for Hm0 and Te
    hm0_bins = np.arange(0, df['significant_wave_height_meters'].max() + hm0_step, hm0_step)
    te_bins = np.arange(0, df['energy_period_seconds'].max() + te_step, te_step)

    # Add bin columns to the dataframe
    df_binned = df.copy()
    df_binned['hm0_bin'] = pd.cut(df['significant_wave_height_meters'], bins=hm0_bins, labels=hm0_bins[:-1], include_lowest=True)
    df_binned['te_bin'] = pd.cut(df['energy_period_seconds'], bins=te_bins, labels=te_bins[:-1], include_lowest=True)

    # Convert category types to float
    df_binned['hm0_bin'] = df_binned['hm0_bin'].astype(float)
    df_binned['te_bin'] = df_binned['te_bin'].astype(float)

    # Sample from each bin combination
    samples = []
    for hm0_val in df_binned['hm0_bin'].unique():
        for te_val in df_binned['te_bin'].unique():
            bin_data = df_binned[
                (df_binned['hm0_bin'] == hm0_val) & (df_binned['te_bin'] == te_val)
            ]
            if not bin_data.empty:
                # Sample min(n_samples, bin size) rows from this bin
                bin_samples = bin_data.sample(
                    n=min(n_samples, len(bin_data)),
                    random_state=42,  # For reproducibility
                )
                samples.append(bin_samples)

    # Combine all samples
    if samples:
        result = pd.concat(samples, axis=0).reset_index(drop=True)
        # Add bin center values for reference
        result['hm0_bin_center'] = result['hm0_bin'] + (hm0_step / 2)
        result['te_bin_center'] = result['te_bin'] + (te_step / 2)
        result.insert(0, 'station_number', result.pop('station_number'))
        return result
    else:
        return pd.DataFrame()
Initial sampling with one sample per bin yielded 76 half-hour segments for WETS (CDIP 225) and 112 segments for Nags Head (CDIP 243), providing a manageable dataset for initial model development. To explore the impact of dataset size on model performance, we also created expanded datasets with up to five samples per bin (at most 380 segments for WETS and 560 for Nags Head). While these larger datasets offer more comprehensive coverage of wave conditions, they require significantly more computational resources for training.
This sampling approach ensures our training data captures the diverse wave conditions present at each site while maintaining computational feasibility. The structured nature of the sampling helps prevent bias toward more common wave conditions, potentially improving model robustness across different sea states.
The five-sample-per-bin datasets, containing 353 segments for WETS and 493 for Nags Head (fewer than five times the single-sample counts because sparsely populated bins hold fewer than five segments), were archived for future research; this project focuses on the more computationally manageable single-sample datasets.
4 Deep Learning Models
Building on the architectural overview presented in Section 1.3, we implemented three neural network models using PyTorch (Ansel et al. 2024) to predict ocean wave displacements. Each model architecture was chosen and designed to capture different aspects of the temporal patterns present in wave motion.
4.1 LSTM Model
Our base LSTM model, shown in Listing 5, provides a straightforward approach to sequence prediction. The model processes three-dimensional displacement inputs through stacked LSTM layers with dropout regularization. This architecture enables the model to learn wave patterns at different time scales. The final linear layer maps the LSTM outputs back to displacement predictions, creating a direct sequence-to-sequence prediction framework.
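A condensed sketch of such a stacked-LSTM-plus-linear-head model is shown below. This is illustrative only, not the exact Listing 5; the class name, defaults, and the 3-axis input dimension are assumptions based on the description above.

```python
import torch
import torch.nn as nn

class BaseLSTM(nn.Module):
    """Stacked LSTM with dropout and a linear head (illustrative sketch,
    not the report's exact Listing 5)."""

    def __init__(self, input_dim=3, hidden_dim=128, num_layers=2, dropout=0.2):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers,
                            batch_first=True, dropout=dropout)
        # Map LSTM features back to the three displacement components
        self.head = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        out, _ = self.lstm(x)   # (batch, seq, hidden_dim)
        return self.head(out)   # (batch, seq, input_dim)

model = BaseLSTM()
x = torch.randn(4, 128, 3)  # 4 windows of 128 samples, 3 displacement axes
y = model(x)
print(y.shape)  # torch.Size([4, 128, 3])
```

Applying the head at every time step yields the direct sequence-to-sequence mapping described above: one predicted displacement vector per input sample.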
4.2 Enhanced LSTM Model
The enhanced LSTM implementation, detailed in Listing 6, extends the base model with several architectural improvements. Bidirectional processing allows the model to consider both past and future context within each input window when making predictions. The addition of skip connections helps maintain gradient flow through the deep network, while layer normalization stabilizes training.
Listing 6: Enhanced Long Short-Term Memory PyTorch Model
class EnhancedLSTMModel(WavePredictionModel):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 128,
        num_layers: int = 2,
        dropout: float = 0.2,
        learning_rate: float = 1e-3,
        bidirectional: bool = True,
    ):
        super().__init__(input_dim, learning_rate)
        self.save_hyperparameters()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional

        # Input processing
        self.input_layer = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout / 2),
        )

        # Main LSTM layers with skip connections
        self.lstm_layers = nn.ModuleList()
        self.layer_norms = nn.ModuleList()
        lstm_input_dim = hidden_dim
        lstm_output_dim = hidden_dim * 2 if bidirectional else hidden_dim
        for _ in range(num_layers):
            self.lstm_layers.append(
                nn.LSTM(
                    lstm_input_dim,
                    hidden_dim,
                    num_layers=1,
                    batch_first=True,
                    bidirectional=bidirectional,
                    dropout=0,
                )
            )
            self.layer_norms.append(nn.LayerNorm(lstm_output_dim))
            lstm_input_dim = lstm_output_dim

        # Output processing
        self.output_layers = nn.ModuleList(
            [
                nn.Linear(lstm_output_dim, hidden_dim),
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.Linear(hidden_dim // 2, input_dim),
            ]
        )
        self.dropouts = nn.ModuleList(
            [nn.Dropout(dropout) for _ in range(len(self.output_layers))]
        )

        # Skip connection
        self.skip_connection = nn.Linear(input_dim, input_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Store original input for skip connection
        original_input = x

        # Input processing
        x = self.input_layer(x)

        # Process through LSTM layers with residual connections
        for lstm, norm in zip(self.lstm_layers, self.layer_norms):
            residual = x
            x, _ = lstm(x)
            x = norm(x)
            if residual.shape == x.shape:
                x = x + residual

        # Output processing
        for linear, dropout in zip(self.output_layers[:-1], self.dropouts[:-1]):
            residual = x
            x = linear(x)
            x = F.relu(x)
            x = dropout(x)
            if residual.shape == x.shape:
                x = x + residual

        # Final output layer
        x = self.output_layers[-1](x)

        # Add skip connection
        x = x + self.skip_connection(original_input)
        return x

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
4.3 Transformer Model
Our Transformer implementation, shown in Listing 7, takes a fundamentally different approach to sequence modeling. Rather than processing the wave motion sequentially, the model uses self-attention mechanisms to directly capture relationships between any two points in the input sequence. This architecture, combined with multi-head attention and position-wise feed-forward networks, should enable the model to identify both short-term wave patterns and longer-range dependencies in the displacement data.
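A minimal encoder-only sketch of this idea is shown below. It is not the report's Listing 7: the class name, sinusoidal positional encoding details, and hyperparameters here are illustrative assumptions consistent with the description above.

```python
import math
import torch
import torch.nn as nn

class WaveTransformer(nn.Module):
    """Encoder-only Transformer for displacement sequences (illustrative
    sketch, not the report's exact Listing 7)."""

    def __init__(self, input_dim=3, d_model=128, nhead=4, num_layers=2, max_len=512):
        super().__init__()
        self.embed = nn.Linear(input_dim, d_model)
        # Fixed sinusoidal positional encoding so attention sees sample order
        pos = torch.arange(max_len).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe)
        layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers)
        self.head = nn.Linear(d_model, input_dim)

    def forward(self, x):
        # Self-attention relates every pair of time steps in the window
        h = self.embed(x) + self.pe[: x.size(1)]
        return self.head(self.encoder(h))

model = WaveTransformer()
y = model(torch.randn(2, 128, 3))
print(y.shape)  # torch.Size([2, 128, 3])
```

Unlike the LSTM, every output step here can attend to every input step directly, which is the property expected to help with longer-range dependencies.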
We implemented the model training process using PyTorch Lightning (Falcon and the PyTorch Lightning team 2019) to standardize the training, validation, and testing procedures. Mean Absolute Error was selected as the loss function for evaluating prediction accuracy. The data was split 80-20 into training and test sets, with the training portion further divided 80-20 for validation.
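The nested 80-20 splits imply the following fractions of the full dataset; this quick check is independent of the training framework, and the window count is a hypothetical stand-in.

```python
n = 1000  # hypothetical number of windows

n_test = int(0.2 * n)          # 20% held-out test set
n_trainval = n - n_test        # 80% for training + validation
n_val = int(0.2 * n_trainval)  # 20% of that 80%, i.e. 16% of the total
n_train = n_trainval - n_val   # remaining 64% of the total

print(n_train, n_val, n_test)  # 640 160 200
```

So the models are fit on 64% of the windows, early stopping and tuning see 16%, and the final metrics in Section 6 come from the untouched 20%.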
All models were trained with consistent hyperparameters: learning rate of 0.001, Adam optimizer, 128-sample input and prediction windows, and 128 hidden dimensions. Training data consisted of one 30-minute dataset per sea state bin for each CDIP location. We tested several model configurations, including baseline LSTM (25 epochs), extended LSTM (100 epochs), varying LSTM layer depths (4 and 6 layers), and both basic and enhanced Transformer architectures. Model specifications are shown in Table 3.
Figure 27: Enhanced LSTM Model: Training vs. Validation Mean Absolute Error
5.2 Training Results Summary
The training results, shown in Figure 21 through Figure 27, indicate superior performance from the LSTM-based models. All LSTM variants demonstrated stable learning curves without significant overfitting. Notably, the 100-epoch LSTM model showed continued improvement in both training and validation loss, suggesting potential benefits from extended training periods. The Transformer models, while marginally functional, generally showed poor learning patterns compared to their LSTM counterparts.
6 Results
Model performance was evaluated using a test set comprising 20% of the sampled data. Each model generated predictions using 128-sample input windows, with results mapped back to the original time series for comparison with measured displacements.
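The window-to-series mapping can be sketched as follows. The report does not state whether test windows overlap, so this sketch assumes consecutive, non-overlapping 128-sample windows and uses an identity "model" as a placeholder.

```python
import numpy as np

n_samples = 128
series = np.arange(5 * n_samples, dtype=float)  # stand-in displacement record

# Split the record into consecutive, non-overlapping 128-sample windows
windows = series.reshape(-1, n_samples)

# A trained model would map each input window to a predicted window;
# the identity stands in for it here
predicted = windows.copy()

# Mapping back: window i covers samples [i * 128, (i + 1) * 128)
reconstructed = predicted.reshape(-1)
print(np.array_equal(reconstructed, series))  # True
```

Concatenating the per-window predictions in order reproduces a series aligned sample-for-sample with the measured displacements, which is what the comparisons below rely on.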
def plot_test_section_compared_to_input(index, this_bins_df, this_source, this_targets,
                                        this_predictions, n_samples=128):
    # Calculate start and stop indices
    start = index * n_samples
    stop = start + n_samples

    # Get source path and load input data
    source_path = this_source.iloc[index]['Source Path']
    input_df = pd.read_parquet(source_path)

    # Get statistics for the title
    stats = this_bins_df[this_bins_df["path"] == source_path]

    # Create figure
    fig, ax = plt.subplots(figsize=(16, 3))

    # Create index arrays for proper alignment
    input_index = np.arange(n_samples * 2)
    shifted_index = np.arange(n_samples, 2 * n_samples)

    # Plot input data with original index
    ax.plot(input_index,
            input_df['vert_displacement_meters'].iloc[:n_samples * 2].values,
            linewidth=0.85, label="Input", alpha=0.7)

    # Plot target and prediction with shifted index
    ax.plot(shifted_index,
            this_targets['vert_displacement_meters'].iloc[start:stop].values,
            label="Target", linewidth=0.85)
    ax.plot(shifted_index,
            this_predictions['vert_displacement_meters'].iloc[start:stop].values,
            label="Prediction", linewidth=0.75)

    # Configure plot
    plt.ylabel("Vertical Displacement [m]")
    plt.legend(loc="upper right")
    plt.title(
        f"CDIP {stats['station_number'].item()} - $H_{{m_0}}$: {stats['hm0_bin'].item()}, "
        f"$T_{{e}}$: {stats['te_bin'].item()}"
    )
    plt.tight_layout()
    plt.show()


def plot_test_section(index, this_bins_df, this_source, this_targets,
                      this_predictions, n_samples=128):
    # Calculate start and stop indices
    start = index * n_samples
    stop = start + n_samples

    scale_factor = 0.5
    this_targets = this_targets.copy()
    this_predictions = this_predictions.copy()
    this_targets *= scale_factor
    this_predictions *= scale_factor

    # Get source path for statistics lookup
    source_path = this_source.iloc[index]['Source Path']

    # Get statistics for the title
    stats = this_bins_df[this_bins_df["path"] == source_path]

    # Create figure
    fig, ax = plt.subplots(figsize=(16, 3))
    ax.plot(this_targets['vert_displacement_meters'].iloc[start:stop].values,
            label="Target", linewidth=0.85, marker=".", markersize=4)
    ax.plot(this_predictions['vert_displacement_meters'].iloc[start:stop].values,
            label="Prediction", linewidth=0.75, marker=".", markersize=4)

    # Configure plot
    plt.ylabel("Vertical Displacement [m]")
    plt.legend(loc="upper right")
    plt.title(
        f"CDIP {stats['station_number'].item()} - $H_{{m_0}}$: {stats['hm0_bin'].item()}, "
        f"$T_{{e}}$: {stats['te_bin'].item()}",
        fontsize=18,
    )
    plt.tight_layout()
    plt.show()
6.1 Select Visual Analysis
To examine model performance across different wave conditions, we present a series of representative time series comparisons. Each figure shows a 128-sample window of predictions alongside the corresponding measured wave displacements, with varying combinations of significant wave height (H_{m_0}) and energy period (T_e). These comparisons provide insight into how the models handle different wave states and reveal characteristic prediction patterns.
The time series comparisons, shown in Figure 28 through Figure 39, reveal varying prediction quality across different wave conditions. A common pattern emerges where predictions show larger errors at the start of each sequence but improve towards the end. This behavior suggests the models require several time steps to establish the wave state context before making accurate predictions. The LSTM models in particular demonstrate this adaptation, likely due to their ability to build and refine internal state representations as they process the sequence. Overall, the baseline model demonstrates capacity to capture underlying wave patterns, though prediction accuracy varies with wave state conditions.
Figure 40 through Figure 46 present the complete test set predictions against measured displacements, offering a comprehensive view of model performance across all wave conditions. The LSTM-based architectures demonstrate strong predictive capabilities, closely tracking the measured wave displacements across varying sea states. In particular, the baseline, enhanced, 4-layer, and 6-layer LSTM models show consistent prediction accuracy, suggesting successful learning of the underlying wave dynamics across diverse conditions.
The visual alignment between predictions and measurements for the LSTM models indicates their ability to capture both the frequency and amplitude characteristics of the wave motion. This comprehensive fit across different wave conditions supports the quantitative metrics and demonstrates the models’ generalization capabilities. In contrast, the Transformer models show notable deviation from the measured displacements, confirming their limited effectiveness for this prediction task.
6.3 Results Comparison
We evaluated model performance using three metrics: Mean Absolute Error (MAE), coefficient of determination (R^2), and Pearson’s correlation coefficient (\rho).
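These three metrics follow their standard definitions; a minimal NumPy sketch (the report's actual implementation is not shown here, but these formulas match common library behavior such as scikit-learn's) is:

```python
import numpy as np

def mae(y_true, y_pred):
    # Mean Absolute Error: average magnitude of prediction error
    return np.mean(np.abs(y_true - y_pred))

def r2(y_true, y_pred):
    # Coefficient of determination: 1 - residual / total sum of squares
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    return 1.0 - ss_res / ss_tot

def pearson(y_true, y_pred):
    # Pearson's correlation coefficient from the 2x2 correlation matrix
    return np.corrcoef(y_true, y_pred)[0, 1]

y_true = np.array([0.0, 1.0, 2.0, 3.0])
y_pred = np.array([0.1, 0.9, 2.1, 2.9])

print(mae(y_true, y_pred), r2(y_true, y_pred), pearson(y_true, y_pred))
```

MAE keeps the units of the data (meters here), while R^2 and ρ are dimensionless, which is why the tables below report them separately.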
results_df = pd.DataFrame(result_stats)
6.3.1 Mean Absolute Error
results_df = results_df.sort_values(["mae", "station"])

plt.figure(figsize=(6, 3.5))
sns.barplot(results_df, y="label", x="mae", hue="station")
for i in plt.gca().containers:
    plt.bar_label(i, fmt='%.2f', padding=3)
plt.ylabel(None)
plt.xlabel("Mean Absolute Error")
Figure 47: Mean Absolute Error Comparison by Model (Lower is better)
plt.figure(figsize=(8, 4.0))
sns.barplot(results_df, y="station", x="mae", hue="label")
for i in plt.gca().containers:
    plt.bar_label(i, fmt='%.2f', padding=3)
plt.ylabel(None)
plt.xlabel("Mean Absolute Error")
plt.legend(bbox_to_anchor=(1.25, 1.0), loc='upper center', ncol=1,
           frameon=False, title="Model")
Figure 48: Mean Absolute Error Comparison by Station (Lower is better)
6.3.2 Coefficient of Determination, R^2
results_df = results_df.sort_values(["r2"], ascending=False)

plt.figure(figsize=(6, 3.5))
sns.barplot(results_df, y="label", x="r2", hue="station")
for i in plt.gca().containers:
    plt.bar_label(i, fmt='%.2f', padding=3)
plt.ylabel(None)
plt.xlabel("$R^2$")
Figure 49: R^2 Comparison by Model
plt.figure(figsize=(8, 4.0))
sns.barplot(results_df, y="station", x="r2", hue="label")
for i in plt.gca().containers:
    plt.bar_label(i, fmt='%.2f', padding=3)
plt.ylabel(None)
plt.xlabel("$R^2$")
plt.legend(bbox_to_anchor=(1.25, 1.0), loc='upper center', ncol=1,
           frameon=False, title="Model")
Figure 50: R^2 Comparison by Station
6.3.3 Pearson’s Correlation (\rho)
results_df = results_df.sort_values(["correlation"], ascending=False)

plt.figure(figsize=(6, 3.5))
sns.barplot(results_df, y="label", x="correlation", hue="station")
for i in plt.gca().containers:
    plt.bar_label(i, fmt='%.2f', padding=3)
plt.ylabel(None)
plt.xlabel("Correlation")
Figure 51: Pearson’s Correlation Comparison by Model
plt.figure(figsize=(8, 4.0))
sns.barplot(results_df, y="station", x="correlation", hue="label")
for i in plt.gca().containers:
    plt.bar_label(i, fmt='%.2f', padding=3)
plt.ylabel(None)
plt.xlabel("Correlation")
plt.legend(bbox_to_anchor=(1.25, 1.0), loc='upper center', ncol=1,
           frameon=False, title="Model")
Figure 52: Pearson’s Correlation Comparison by Station
results_df
Table 4: Results Summary

| Model                | Station | MAE [m]  | R^2       | ρ        |
|----------------------|---------|----------|-----------|----------|
| 4 Layer LSTM         | 225     | 0.502155 | 0.889149  | 0.942949 |
| 4 Layer LSTM         | 243     | 0.874835 | 0.868795  | 0.932219 |
| 100 Epoch LSTM       | 225     | 0.582889 | 0.866286  | 0.930816 |
| 6 Layer LSTM         | 225     | 0.565404 | 0.856992  | 0.925749 |
| 6 Layer LSTM         | 243     | 0.946516 | 0.851715  | 0.922910 |
| Enhanced LSTM        | 225     | 0.699822 | 0.837119  | 0.915022 |
| Baseline             | 225     | 0.692493 | 0.829845  | 0.911056 |
| 100 Epoch LSTM       | 243     | 1.139891 | 0.807409  | 0.898738 |
| Enhanced LSTM        | 243     | 1.223194 | 0.789425  | 0.888552 |
| Baseline             | 243     | 1.235250 | 0.788793  | 0.888162 |
| Enhanced Transformer | 243     | 2.793854 | 0.004303  | 0.076359 |
| Transformer          | 225     | 1.880350 | 0.002838  | 0.072718 |
| Enhanced Transformer | 225     | 1.932608 | -0.047880 | 0.019221 |
| Transformer          | 243     | 2.797447 | -0.000079 | 0.005579 |
6.4 Results Summary
As shown in Figure 47 through Figure 52, the LSTM-based models consistently outperformed the Transformer architectures across all three metrics. The 4-layer LSTM achieved the best performance, with MAE of 0.50\,\mathrm{m} and 0.87\,\mathrm{m} for WETS and Nags Head respectively, along with the highest R^2 values (0.89, 0.87) and correlation coefficients (0.94, 0.93). Increasing training epochs from 25 to 100 improved the baseline LSTM’s performance, particularly for WETS, where the MAE decreased from 0.69\,\mathrm{m} to 0.58\,\mathrm{m}. However, increasing model depth to 6 layers showed no additional benefit.
Both the basic and enhanced Transformer models performed poorly, with R^2 values near zero and correlation coefficients below 0.1, suggesting they failed to capture meaningful wave patterns. The enhanced LSTM showed no improvement over the baseline model, indicating that the additional architectural complexity did not benefit the wave prediction task.
Across all models, prediction accuracy was consistently better for WETS compared to Nags Head, as shown in Figure 48. This aligns with our earlier observation that WETS exhibits more regular wave patterns, while Nags Head experiences more variable conditions.
7 Conclusion
This project developed deep learning models for wave surface elevation prediction using data from two contrasting locations: the Wave Energy Test Site (WETS) in Hawaii and Nags Head, North Carolina. Through systematic data processing and structured sampling across sea states, we created representative training datasets capturing the range of wave conditions at each site. Multiple neural network architectures were implemented, trained, and tested using PyTorch, with performance evaluated through mean absolute error, coefficient of determination (R^2), and correlation metrics. Comprehensive testing across different sea states and visualization of predicted wave patterns provided quantitative and qualitative assessment of model capabilities. The results demonstrate the viability of neural networks for wave prediction, though with varying degrees of success across different architectures.
7.1 Key Findings
The LSTM-based models consistently outperformed Transformer architectures, with the 4-layer LSTM achieving the best results (\mathrm{MAE} = 0.50\,\mathrm{m} at WETS). Model performance showed sensitivity to architecture choices, where increasing complexity beyond four layers provided diminishing returns. Prediction accuracy was generally better at WETS than Nags Head, reflecting the more regular wave patterns at the Hawaiian site. While our models showed good prediction capabilities, some evidence of overfitting suggests room for improvement in model regularization.
7.2 Lessons Learned
The temporal nature of wave prediction requires careful consideration of model architecture and data preparation. LSTM networks proved particularly effective at capturing wave patterns, while Transformer models, despite their success in other sequential tasks, failed to capture meaningful wave dynamics. Increasing training epochs consistently improved performance, suggesting that extended training periods might yield further improvements. The structured sampling approach across sea states proved valuable for creating representative training datasets.
7.3 Future Work
Several promising directions for future research emerge from this work. Extension to north and east displacement predictions would provide a more complete wave motion model. Physics-informed neural networks could incorporate wave dynamics directly into the learning process. Alternative data sampling strategies and real-time validation would enhance model robustness. Additional coastal locations would test model generalization, while optimized scaling methods might improve prediction accuracy. These enhancements could lead to more reliable wave prediction systems for maritime applications.
References
Abdelkhalik, Ossama, Rush Robinett, Shangyan Zou, Giorgio Bacelli, Ryan Coe, Diana Bull, David Wilson, and Umesh Korde. 2016. “On the Control Design of Wave Energy Converters with Wave Prediction.” Journal of Ocean Engineering and Marine Energy 2 (4): 473–83. https://doi.org/10.1007/s40722-016-0048-4.
Ansel, Jason, Edward Yang, Horace He, Natalia Gimelshein, Animesh Jain, Michael Voznesensky, Bin Bao, et al. 2024. “PyTorch 2: Faster Machine Learning Through Dynamic Python Bytecode Transformation and Graph Compilation.” In 29th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, Volume 2 (ASPLOS ’24). ACM. https://doi.org/10.1145/3620665.3640366.
Coastal Data Information Program. 2023a. “Station 225: Kaneohe Bay, WETS, HI - Wave, Sea Surface Temperature, and Ocean Current Time-Series Data.” Scripps Institution of Oceanography; USACE/PACIOOS; UC San Diego Library Digital Collections. https://doi.org/10.18437/C7WC72.
———. 2023b. “Station 243: Nags Head, NC - Wave, Sea Surface Temperature, and Ocean Current Time-Series Data.” Scripps Institution of Oceanography; USACE/CSI; UC San Diego Library Digital Collections. https://doi.org/10.18437/C7WC72.
Hoyer, Stephan, and Joseph Hamman. 2017. “xarray: N-D Labeled Arrays and Datasets in Python.” Journal of Open Research Software 5 (1). https://doi.org/10.5334/jors.148.
Kumar, N. Krishna, R. Savitha, and Abdullah Al Mamun. 2017. “Regional Ocean Wave Height Prediction Using Sequential Learning Neural Networks.” Ocean Engineering 129: 605–12. https://doi.org/10.1016/j.oceaneng.2016.10.033.
Ringwood, John V. 2020. “Wave Energy Control: Status and Perspectives 2020.” IFAC-PapersOnLine 53 (2): 12271–82. https://doi.org/10.1016/j.ifacol.2020.12.1162.
Sak, Haşim, Andrew Senior, and Françoise Beaufays. 2014. “Long Short-Term Memory Based Recurrent Neural Network Architectures for Large Vocabulary Speech Recognition.” https://arxiv.org/abs/1402.1128.
Vaswani, Ashish, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. 2023. “Attention Is All You Need.” https://arxiv.org/abs/1706.03762.